import os
import math
import numpy as np
from compute_eer import *
import argparse

# Constant pi stored as a 0-d NumPy array (referenced by NLScore's log terms).
pi = np.array(np.pi)

def LoadTrials(trials_file):
    '''
    Parse a trials file into three parallel arrays.

    Each line has the form:
        <enroll-utt-id> <test-utt-id> <target|nontarget>

    Returns (enroll_ids, test_ids, labels) where labels holds 1 for
    "target" pairs and 0 for everything else.
    '''
    assert os.path.exists(trials_file)
    enroll_ids = []
    test_ids = []
    labels = []
    with open(trials_file) as fin:
        for row in fin:
            spk, utt, flag = row.strip().split()
            enroll_ids.append(spk)
            test_ids.append(utt)
            labels.append(1 if flag == "target" else 0)
    return np.array(enroll_ids), np.array(test_ids), np.array(labels)


def LoadNumUtts(num_utts):
    '''Read a "<spk> <num>" mapping file and return a dict spk -> int(num).'''
    assert os.path.exists(num_utts)
    with open(num_utts) as fin:
        pairs = (row.strip().split() for row in fin)
        return {spk: int(cnt) for spk, cnt in pairs}


def LoadGlobalMean(f_global_mean):
    '''
    Read a Kaldi-style global mean vector from a one-line text file and
    return it as a 1-D float ndarray.

    The line is tokenized and the first and last tokens are dropped —
    presumably the leading "[" (or a name token) and the trailing "]" of
    Kaldi's text vector format.
    '''
    # Context manager replaces the manual open/readline/close triplet.
    with open(f_global_mean, 'r') as f:
        line = f.readline()

    part = line.strip().split()
    # np.float was removed in NumPy 1.24; the builtin float is the
    # equivalent (float64) dtype.
    global_mean = np.array(part[1:-1], dtype=float)
    return global_mean


def LoadPLDA(mdl_plda):
    '''
    Parse a text-format Kaldi PLDA model file.

    Lines are classified by token count relative to the transform
    dimension `dim` (taken from the third line, which is a bare matrix row):
      dim + 3 tokens -> mean vector plda_b          ("<tag> [ v1..vn ]")
      dim + 2 tokens -> between-class var plda_SB   ("[ v1..vn ]")
      dim + 1 tokens -> final transform row         ("v1..vn ]")
      dim     tokens -> interior transform row
    Any other line is skipped.

    Returns (plda_b, plda_W, plda_SB); plda_b / plda_SB are None when the
    file contains no matching line (malformed model).
    '''
    with open(mdl_plda, 'r') as f:
        lines = f.readlines()

    # The third line is assumed to be a pure matrix row, fixing `dim`.
    dim = len(lines[2].strip().split())

    plda_W = []
    plda_b = None   # explicit init: previously unbound on malformed input
    plda_SB = None
    for line in lines:
        part = line.strip().split()
        if len(part) == (dim + 3):
            plda_b = np.array(part[2:-1], dtype=float)
        elif len(part) == dim:
            plda_W.append(np.array(part, dtype=float))
        elif len(part) == (dim + 1):
            plda_W.append(np.array(part[:-1], dtype=float))
        elif len(part) == (dim + 2):
            # np.float was removed in NumPy 1.24; use builtin float.
            plda_SB = np.array(part[1:-1], dtype=float)
    return plda_b, np.array(plda_W), plda_SB


def GetNormalizationFactor(transformed_vector, num_utts, plda_SB):
    '''
    Return the PLDA length-normalization scale for a transformed vector.

    The scale is sqrt(dim / v' C^-1 v) where C = I/num_utts + plda_SB is
    the per-dimension (diagonal) covariance of the averaged vector.
    '''
    assert num_utts > 0
    dim = len(transformed_vector)
    # Diagonal precision of the num_utts-session average.
    precision = 1.0 / (1.0 / num_utts + plda_SB)
    quad_form = np.dot(precision, transformed_vector ** 2)
    return math.sqrt(dim / quad_form)


def TransformVector(vector, num_utts, plda_W, plda_b, plda_SB, plda_dim, simple_length_norm, normalize_length):
    '''
    Center a vector, project it into PLDA space, truncate to plda_dim,
    and optionally length-normalize.

    simple_length_norm: scale by sqrt(len(vector)) / ||projected|| instead
    of the model-based factor from GetNormalizationFactor. Note the sqrt
    uses the ORIGINAL (pre-truncation) dimension, matching the prior
    behavior of this function.
    '''
    full_dim = len(vector)
    projected = np.dot(vector - plda_b, plda_W)[:plda_dim]
    if simple_length_norm:
        factor = math.sqrt(full_dim) / np.linalg.norm(projected)
    else:
        factor = GetNormalizationFactor(projected, num_utts, plda_SB[:plda_dim])
    if normalize_length:
        return projected * factor
    return projected


def NLScore(enroll_vec, enroll_num, test_vec, SB, SW):
    '''
    Normalized-likelihood score with uncertain enrollment means.

    SB: between-speaker variance (per dimension)
    SW: within-speaker variance (per dimension)

    Scores the test vector under the enrolled-speaker hypothesis versus
    the background model and returns half their log-likelihood difference.
    '''
    # Posterior mean and predictive variance of the speaker model given
    # enroll_num enrollment sessions.
    shrinkage = enroll_num * SB / (enroll_num * SB + SW)
    post_mean = enroll_vec * shrinkage
    pred_var = SW + SB * SW / (enroll_num * SB + SW)

    # Negative log-likelihoods (Gaussian, summed over dimensions).
    nll_spk = ((test_vec - post_mean) ** 2 / pred_var).sum() + np.log(2 * np.pi * pred_var).sum()
    nll_bkg = (test_vec ** 2 / (SW + SB)).sum() + np.log(2 * np.pi * (SW + SB)).sum()

    return 0.5 * (nll_bkg - nll_spk)


def ScoreByTrials(enroll_npz, enroll_num_utts, test_npz, test_trials, score_file, global_mean, plda_W, plda_b, plda_SB, plda_dim, simple_length_norm, normalize_length):
    '''
    compute NL scores by trials
    '''
    # load data
    print("Load data")
    enroll_vectors = np.load(enroll_npz)['vectors']
    enroll_spkers = np.load(enroll_npz)['spker_label']
    enroll_utters = np.load(enroll_npz)['utt_label']

    test_vectors = np.load(test_npz)['vectors']
    test_spkers = np.load(test_npz)['spker_label']
    test_utters = np.load(test_npz)['utt_label']

    # subtract global mean
    print("Centering vectors")
    enroll_vectors = enroll_vectors - global_mean
    test_vectors = test_vectors - global_mean

    # build hashmap enroll_spk -> utters
    enroll_spk2utt = {}
    for idx in range(len(enroll_spkers)):
        spk = enroll_spkers[idx]
        if spk not in enroll_spk2utt:
            enroll_spk2utt[spk] = []
        enroll_spk2utt[spk].append(enroll_vectors[idx])

    # build hashmap test_utt -> utter
    test_spk2utt = {}
    for idx in range(len(test_utters)):
        label = test_utters[idx]
        test_spk2utt[label] = test_vectors[idx]

    # load trials and compute EER
    print("Load test trials")
    enroll_id, test_id, target_id = LoadTrials(test_trials)
    num_utts = LoadNumUtts(enroll_num_utts)

    print("Transform vectors")
    enroll_trans_dict = {}
    for idx in range(len(enroll_spkers)):
        spk = enroll_spkers[idx]
        enroll_vecs = enroll_spk2utt[spk]
        enroll_vec = np.mean(np.array(enroll_vecs), axis=0)
        enroll_num = num_utts[spk]
        enroll_trans_vec = TransformVector(enroll_vec, enroll_num, plda_W, plda_b, plda_SB, plda_dim, simple_length_norm, normalize_length)
        enroll_trans_dict[spk] = enroll_trans_vec

    test_trans_dict = {}
    for idx in range(len(test_utters)):
        label = test_utters[idx]
        test_vec = test_spk2utt[label]
        test_trans_vec = TransformVector(test_vec, 1, plda_W, plda_b, plda_SB, plda_dim, simple_length_norm, normalize_length)
        test_trans_dict[label] = test_trans_vec


    print("Compute NL scoring")
    target_scores = []
    nontarget_scores = []
    foo = open(score_file, 'w')
    for i in range(len(target_id)):
        enroll_num = num_utts[enroll_id[i]]
        enroll_trans_vec = enroll_trans_dict[enroll_id[i]]
        test_trans_vec = test_trans_dict[test_id[i]]
        score = NLScore(enroll_trans_vec, enroll_num, test_trans_vec, plda_SB[:plda_dim], 1)
        foo.write(' '.join([enroll_id[i], test_id[i], str(score)]) + '\n')

        if target_id[i]:
            target_scores.append(score)
        else:
            nontarget_scores.append(score)

    EER, thres = compute_eer(target_scores, nontarget_scores)
    print("EER: {:.3f}% and Threshold: {:.3f}".format(EER*100.0, thres))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--enroll-npz', default='enroll/xvector.npz', help='npz file of enroll vector')
    parser.add_argument(
        '--enroll-num-utts', default='enroll/num_utts.ark', help='mapping file of spker to utter number')
    parser.add_argument(
        '--test-npz', default='test/xvector.npz', help='npz file of test vector')
    parser.add_argument(
        '--trials', default='trials.trl', help='file of test trials')
    parser.add_argument(
        '--score', default='score.foo', help='file of trial scores')
    parser.add_argument(
        '--global-mean', default='mean.vec', help='file of global mean in Kaldi')
    parser.add_argument(
        '--plda', default='plda', help='file of plda model in Kaldi')
    parser.add_argument(
        '--plda-dim', type=int, default=150, help='dim of PLDA')
    parser.add_argument(
        '--simple-length-norm', action='store_true', default=False, help='process simple length norm (2-norm)')
    parser.add_argument(
        '--normalize-length', action='store_true', default=False, help='process length normlization')

    args = parser.parse_args()

    global_mean = LoadGlobalMean(args.global_mean)

    plda_b, plda_W, plda_SB = LoadPLDA(args.plda)
    plda_W = plda_W.T # W transpose

    # check model params
    assert(len(plda_b)==len(plda_W))
    assert(len(plda_b)==len(plda_SB))

    # A bare filename (e.g. the default 'score.foo') has an empty dirname;
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when there is one. exist_ok avoids a race with concurrent runs.
    score_dir = os.path.dirname(args.score)
    if score_dir:
        os.makedirs(score_dir, exist_ok=True)

    ScoreByTrials(args.enroll_npz, args.enroll_num_utts, args.test_npz, args.trials, args.score, global_mean, plda_W, plda_b, plda_SB, args.plda_dim, args.simple_length_norm, args.normalize_length)
